package com.demo.config.sharding.algorithm;

import com.google.common.collect.Range;
import lombok.extern.slf4j.Slf4j;
import org.apache.shardingsphere.sharding.api.sharding.standard.PreciseShardingValue;
import org.apache.shardingsphere.sharding.api.sharding.standard.RangeShardingValue;
import org.apache.shardingsphere.sharding.api.sharding.standard.StandardShardingAlgorithm;
import org.springframework.context.annotation.Configuration;

import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;

@Slf4j
@Configuration
public class OrderTableAlgorithm implements StandardShardingAlgorithm<LocalDateTime> {

    private static final DateTimeFormatter dateFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
    private static final DateTimeFormatter monthFormatter = DateTimeFormatter.ofPattern("yyyyMM");

    /**
     * 获取查询对应分表名
     * @param collection
     * @param preciseShardingValue
     * @return
     */
    @Override
    public String doSharding(Collection<String> collection, PreciseShardingValue<LocalDateTime> preciseShardingValue) {
        LocalDateTime date = preciseShardingValue.getValue();
        if (date == null) {
            return collection.stream().findFirst().get();
        }

        String tableName = preciseShardingValue.getLogicTableName();

        // 如果查询范围包括基础表，则需要联合基础表进行查询
        LocalDateTime minBaseDate = LocalDateTime.parse(StaticValue.userBaseTableMinDate, dateFormatter);
        if (date.isAfter(minBaseDate)) {
            String tableSuffix = date.format(monthFormatter);
            tableName = tableName.concat("_").concat(tableSuffix);
        }

        String t = tableName;
        return collection.stream().filter(str -> str.equals(t)).findFirst().orElseThrow(() -> new RuntimeException(t + "分表不存在"));
    }

    /**
     * 范围查询获取所有分表
     *
     * @param collection
     * @param rangeShardingValue
     * @return 分表集合
     */
    @Override
    public Collection<String> doSharding(Collection collection, RangeShardingValue rangeShardingValue) {

        String logicTableName = rangeShardingValue.getLogicTableName();
        Range<LocalDateTime> valueRange = rangeShardingValue.getValueRange();
        Set<String> tableRange = extracted(logicTableName, valueRange.lowerEndpoint(), valueRange.upperEndpoint());

        return tableRange;
    }



    /**
     * 根据时间范围获取分表集合
     *
     * @param logicTableName
     * @param lowerEndpoint
     * @param upperEndpoint
     * @return
     */
    private Set<String> extracted(String logicTableName, LocalDateTime lowerEndpoint, LocalDateTime upperEndpoint) {
        Set<String> rangeTable = new HashSet<>();

        // 如果查询范围包括基础表，则需要联合基础表进行查询
        LocalDateTime minBaseDate = LocalDateTime.parse(StaticValue.userBaseTableMinDate, dateFormatter);
        LocalDateTime maxBaseDate = LocalDateTime.parse(StaticValue.userBaseTableMaxDate, dateFormatter);

        if (lowerEndpoint.isBefore(minBaseDate)) {
            lowerEndpoint = minBaseDate;
            rangeTable.add(logicTableName);
        }
        if (upperEndpoint.isAfter(maxBaseDate)) {
            throw new RuntimeException("结束时间不在当前时间内");
        }

        // 便利所有分表
        while (lowerEndpoint.isBefore(upperEndpoint)) {
            String tableName = logicTableName.concat("_").concat(lowerEndpoint.format(monthFormatter));
            rangeTable.add(tableName);
            lowerEndpoint = lowerEndpoint.plusMonths(1);
        }

        // 可能开始时间累加后与结束时间一致
        String tableName = logicTableName.concat("_").concat(upperEndpoint.format(monthFormatter));
        rangeTable.add(tableName);

        return rangeTable;
    }

    @Override
    public void init() {

    }

    @Override
    public String getType() {
        return null;
    }

}
